{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to easily train a V-Net or any other model for lung cancer segmentation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a supplemental notebook for the Medium post on the [Data Analysis Center blog](https://medium.com/data-analysis-center).\n",
    "\n",
    "We suggest using a GPU; the notebook was tested on an NVIDIA GTX 1080. Note that the network is quite large and takes a lot of memory.\n",
    "\n",
    "V-Net is a popular CNN architecture for segmentation of volumetric images such as CT scans; see [Milletari et al.](https://arxiv.org/abs/1606.04797)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import sys\n",
    "import tensorflow as tf\n",
    "sys.path.append('../')\n",
    "\n",
    "from radio.batchflow import Pipeline, FilesIndex, Dataset, F\n",
    "from radio import CTImagesMaskedBatch as CTIMB\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Thu Dec 14 11:57:52 2017       \r\n",
      "+-----------------------------------------------------------------------------+\r\n",
      "| NVIDIA-SMI 375.26                 Driver Version: 375.26                    |\r\n",
      "|-------------------------------+----------------------+----------------------+\r\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\r\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\r\n",
      "|===============================+======================+======================|\r\n",
      "|   0  GeForce GTX 1080    Off  | 0000:02:00.0     Off |                  N/A |\r\n",
      "| 23%   49C    P8    15W / 200W |   4487MiB /  8113MiB |      0%      Default |\r\n",
      "+-------------------------------+----------------------+----------------------+\r\n",
      "|   1  GeForce GTX 1080    Off  | 0000:03:00.0     Off |                  N/A |\r\n",
      "| 30%   52C    P8    15W / 200W |      2MiB /  8112MiB |      0%      Default |\r\n",
      "+-------------------------------+----------------------+----------------------+\r\n",
      "                                                                               \r\n",
      "+-----------------------------------------------------------------------------+\r\n",
      "| Processes:                                                       GPU Memory |\r\n",
      "|  GPU       PID  Type  Process name                               Usage      |\r\n",
      "|=============================================================================|\r\n",
      "+-----------------------------------------------------------------------------+\r\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Examples in this notebook use the [LUNA16 competition dataset](https://luna16.grand-challenge.org/) in MetaImage (mhd/raw) format.\n",
    "\n",
    "You need to specify a file mask for the '\\*.mhd' input files in DIR_LUNA. Here we use the unzipped competition dataset: mhd files are stored in subfolders, and with `no_ext=True` the file names without extension are taken as ids."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "DIR_LUNA = '/notebooks/data/MRT/luna/s*/*.mhd'  # file mask for the .mhd files of LIDC-IDRI 3D scans"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting up the index and the dataset\n",
    "\n",
    "Index all the data and create a Dataset, which conceptually represents all the raw data and lets us do the cool thing: iterate over the data in batches, just like when training a neural network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "index = FilesIndex(path=DIR_LUNA, no_ext=True)\n",
    "lunaset = Dataset(index=index, batch_class=CTIMB)"
   ]
  },
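  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of batch iteration concrete, here is a minimal, library-independent sketch. The helper `iterate_batches` and the ids are hypothetical and are not part of RadIO or batchflow; a real `Dataset` does much more (shuffling, epochs, lazy loading of scans).\n",
    "\n",
    "```python\n",
    "def iterate_batches(ids, batch_size):\n",
    "    # yield consecutive chunks of `ids`, each at most `batch_size` long\n",
    "    for start in range(0, len(ids), batch_size):\n",
    "        yield ids[start:start + batch_size]\n",
    "\n",
    "\n",
    "scan_ids = ['scan_%d' % i for i in range(5)]  # hypothetical ids\n",
    "batches = list(iterate_batches(scan_ids, batch_size=2))\n",
    "print(batches)\n",
    "```\n",
    "\n",
    "The last batch is simply shorter when the number of ids is not a multiple of the batch size."
   ]
  },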
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing everything\n",
    "Here we load the LUNA dataset, split it into train and test parts, normalize Hounsfield Unit values on the fly, resize images (to equalize the spacing along each axis across patients), create masks and crop patches around nodules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "lunaset.split([0.9, 0.1])  # 90 % goes for training\n",
    "\n",
    "# load annotations into df\n",
    "nodules = pd.read_csv('/notebooks/data/MRT/luna/CSVFILES/annotations.csv')\n",
    "\n",
    "pipeline = (Pipeline()\n",
    "        .load(fmt='raw')  # load scans\n",
    "        .normalize_hu(-1000, 400)  # normalize hu\n",
    "        .fetch_nodules_info(nodules=nodules)  # load nodules locations\n",
    "        .unify_spacing(shape=(128, 256, 256), spacing=(2.5, 2.0, 2.0))  # equalize spacing of different scans\n",
    "        .create_mask()  # create masks\n",
    "        .sample_nodules(nodule_size=(32, 64, 64), batch_size=10, share=0.5)  # sample crops\n",
    "       )"
   ]
  },
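  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what HU normalization with a `(-1000, 400)` window means, the sketch below clips values to that window and rescales them linearly. The function `normalize_hu_sketch` is hypothetical, and the `[0, 255]` output range is an assumption made for this example; the `normalize_hu` action in RadIO may differ in details.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def normalize_hu_sketch(scan, min_hu=-1000, max_hu=400):\n",
    "    # clip HU values to [min_hu, max_hu], then rescale linearly to [0, 255]\n",
    "    clipped = np.clip(scan, min_hu, max_hu)\n",
    "    return (clipped - min_hu) / (max_hu - min_hu) * 255.0\n",
    "\n",
    "\n",
    "toy_scan = np.array([-2000.0, -1000.0, -300.0, 400.0, 3000.0])  # toy HU values\n",
    "normalized = normalize_hu_sketch(toy_scan)\n",
    "print(normalized)\n",
    "```\n",
    "\n",
    "Values below -1000 HU (roughly air) and above 400 HU (dense bone) collapse to the ends of the range, so the network sees mostly soft-tissue and lung contrast."
   ]
  },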
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TF-model for segmentation in a loop\n",
    "\n",
    "Very often we run NN models in a simple train loop (see: https://www.tensorflow.org/get_started/mnist/mechanics#train_loop)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Say we define a toy model with inputs, targets, predictions and train_step, as we usually do in TF:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "inputs = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64])\n",
    "targets = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64])\n",
    "\n",
    "session = tf.Session()\n",
    "\n",
    "# oversimplified model\n",
    "reshaped = tf.reshape(inputs, shape=(-1, 32, 64, 64, 1))\n",
    "predictions = tf.reshape(\n",
    "    tf.layers.conv3d(\n",
    "        reshaped, kernel_size=(5, 5, 5), padding='same', filters=1),\n",
    "    shape=(-1, 64, 64))\n",
    "\n",
    "loss = tf.losses.mean_squared_error(labels=targets, predictions=predictions)\n",
    "train_step = tf.train.AdamOptimizer().minimize(loss)\n",
    "\n",
    "session.run(tf.global_variables_initializer())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With RadIO you can flow raw data through our preprocessing pipeline in lazy mode:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "lunapipe = (lunaset >> pipeline)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "...and then train network on batches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_ITERS = 100\n",
    "for i in range(N_ITERS):\n",
    "    batch = lunapipe.next_batch(batch_size=10, n_epochs=None)\n",
    "    session.run(train_step, feed_dict={inputs: batch.images, targets: batch.masks})\n",
    "    print(i)  # iteration counter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, such a loop is sometimes inconvenient and bulky. For **TF** and **Keras** models, training can be done inside a pipeline instead.\n",
    "Suppose we decide to train a V-Net in TF, not an unusual thing these days:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from radio.batchflow.models.tf.vnet import VNet\n",
    "from radio.models.tf.losses import dice_loss\n",
    "from radio import batchflow as bf\n",
    "from radio.models.metrics import dice"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Models from [dataset](http://github.com/analysiscenter/dataset) require specifying an input config and a model config. In the model config we slightly change the final layers of the network to use a sigmoid activation and predict only 1 class (cancerous/non-cancerous). Again, training is done on patches of size (32, 64, 64)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "inputs_config = dict(\n",
    "    images={'shape': (32, 64, 64, 1)},\n",
    "    labels={'name': 'targets', 'shape': (32, 64, 64, 1)}\n",
    ")\n",
    "\n",
    "\n",
    "model_config = dict(\n",
    "    inputs=inputs_config,\n",
    "    optimizer='Adam',\n",
    "    loss=dice_loss,\n",
    "    build=True\n",
    ")\n",
    "\n",
    "model_config['input_block/inputs'] = 'images'\n",
    "model_config['head/num_classes'] = 1\n",
    "model_config['head/layout'] = 'ca'\n",
    "model_config['head/activation'] = tf.nn.sigmoid\n"
   ]
  },
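  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model above is trained with a Dice loss. As a reminder of what that measures, here is a minimal NumPy sketch of a soft Dice coefficient with plain sums in the denominator. The function `dice_coefficient` and the smoothing term `eps` are illustrative assumptions, not the implementation of `dice_loss` in RadIO; Milletari et al. use squared terms in the denominator.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def dice_coefficient(pred, target, eps=1e-7):\n",
    "    # soft Dice: 2 * sum(pred * target) / (sum(pred) + sum(target))\n",
    "    pred = np.asarray(pred, dtype=float)\n",
    "    target = np.asarray(target, dtype=float)\n",
    "    intersection = np.sum(pred * target)\n",
    "    return (2.0 * intersection + eps) / (np.sum(pred) + np.sum(target) + eps)\n",
    "\n",
    "\n",
    "toy_target = np.array([0, 1, 1, 0])          # two cancerous voxels\n",
    "toy_pred = np.array([0.0, 1.0, 0.0, 0.0])    # one of them is found\n",
    "score = dice_coefficient(toy_pred, toy_target)  # about 2 * 1 / (1 + 2)\n",
    "loss_value = 1.0 - score                     # a loss to minimize\n",
    "print(score, loss_value)\n",
    "```\n",
    "\n",
    "Unlike voxel-wise accuracy, the Dice score is not dominated by the large background class, which matters here because nodules occupy a tiny fraction of each patch."
   ]
  },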
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here the network's feed dict is specified directly; see the documentation for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_pipeline = (\n",
    "    pipeline\n",
    "      .init_model('static', VNet, 'vnet', config=model_config)\n",
    "      .train_model('vnet', feed_dict={ \n",
    "          'images': F(CTIMB.unpack, component='images'),\n",
    "          'labels': F(CTIMB.unpack, component='segmentation_targets')\n",
    "      })\n",
    ") << lunaset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "After compiling, you can train it immediately:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<radio.dataset.dataset.pipeline.Pipeline at 0x7f0f50d69358>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_pipeline.run(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn\n",
    "\n",
    "model = train_pipeline.get_model_by_name('vnet')  # the model lives in train_pipeline\n",
    "(\n",
    "    model.train_metrics\n",
    "         .loc[:, ['dice']]\n",
    "         .rolling(16)\n",
    "         .mean()\n",
    "         .plot(figsize=(10, 7))\n",
    ")\n",
    "plt.xlabel('iteration')\n",
    "plt.ylabel('metric')\n",
    "plt.grid(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we want to see what our V-Net model predicts on a patient’s whole scan, hold on, it’s also super easy with the `predict_on_scan` action:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "res_pipe = (\n",
    "    Pipeline().load(fmt='raw')  # load scans\n",
    "    .normalize_hu(-1000, 400)  # normalize hu\n",
    "    .fetch_nodules_info(nodules=nodules)  # load nodules locations\n",
    "    .unify_spacing(shape=(128, 256, 256), spacing=(2.5, 2.0, 2.0))\n",
    "    .predict_on_scan(\n",
    "        model_name='vnet',\n",
    "        strides=(32, 64, 64),\n",
    "        batch_size=4,\n",
    "        dim_ordering='channels_last',\n",
    "        y_component='labels')\n",
    ")"
   ]
  },
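  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, predicting on a whole scan with strides equal to the patch size means tiling the scan with patches, running the model on each patch and putting the results back in place. The sketch below shows only the tiling and reassembly in NumPy, assuming the scan shape is an exact multiple of the patch shape. Both helper functions are hypothetical and are not how `predict_on_scan` is implemented.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def split_into_patches(scan, patch_shape):\n",
    "    # cut a 3D scan into non-overlapping patches (strides equal to patch_shape)\n",
    "    pz, py, px = patch_shape\n",
    "    patches = []\n",
    "    for z in range(0, scan.shape[0], pz):\n",
    "        for y in range(0, scan.shape[1], py):\n",
    "            for x in range(0, scan.shape[2], px):\n",
    "                patches.append(scan[z:z + pz, y:y + py, x:x + px])\n",
    "    return np.stack(patches)\n",
    "\n",
    "\n",
    "def assemble_from_patches(patches, scan_shape):\n",
    "    # put patches back in the same order to rebuild a scan-sized array\n",
    "    pz, py, px = patches.shape[1:]\n",
    "    result = np.zeros(scan_shape, dtype=patches.dtype)\n",
    "    i = 0\n",
    "    for z in range(0, scan_shape[0], pz):\n",
    "        for y in range(0, scan_shape[1], py):\n",
    "            for x in range(0, scan_shape[2], px):\n",
    "                result[z:z + pz, y:y + py, x:x + px] = patches[i]\n",
    "                i += 1\n",
    "    return result\n",
    "\n",
    "\n",
    "toy_scan = np.random.rand(128, 256, 256).astype(np.float32)\n",
    "toy_patches = split_into_patches(toy_scan, (32, 64, 64))\n",
    "print(toy_patches.shape)  # 4 * 4 * 4 = 64 patches of shape (32, 64, 64)\n",
    "restored = assemble_from_patches(toy_patches, toy_scan.shape)\n",
    "```\n",
    "\n",
    "In a real run the model's predictions for each patch, rather than the patches themselves, would be passed to the reassembly step."
   ]
  },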
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "batch = (lunaset >> res_pipe).next_batch(6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# `visualize` is defined a few cells below; run those cells first\n",
    "visualize(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import ipywidgets\n",
    "from ipywidgets import interact"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def visualize(batch):\n",
    "    size = len(batch)\n",
    "    @interact(item=ipywidgets.IntSlider(value=0, min=0, max=size-1),\n",
    "              height=(0.01, 0.99, 0.01))\n",
    "    def visualizer(item, height):\n",
    "        lb, ub = batch.lower_bounds[item], batch.upper_bounds[item]\n",
    "        return plot_arr_slices(height, batch.images[lb: ub, :, :],\n",
    "                               batch.masks[lb: ub, :, :], batch.real_masks[lb: ub, :, :])\n",
    "    return visualizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def plot_arr_slices(height, *arrays, clim=None):\n",
    "    # squeeze=False keeps `axes` two-dimensional even when a single array is passed\n",
    "    fig, axes = plt.subplots(1, len(arrays), figsize=(20, len(arrays)*8), squeeze=False)\n",
    "\n",
    "    for ax, arr in zip(axes[0], arrays):\n",
    "        depth = arr.shape[0]\n",
    "        n_slice = int(depth * height)  # `height` is the relative depth in (0, 1)\n",
    "\n",
    "        # intensity limits are applied only if `clim` is given; otherwise matplotlib autoscales\n",
    "        kwargs = dict() if clim is None else dict(clim=clim)\n",
    "        ax.grid(color='w', linestyle='-', linewidth=0.5)\n",
    "        ax.imshow(arr[n_slice], cmap=plt.cm.gray, **kwargs)\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
